{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xrm1Sci01iDL"
   },
   "source": [
    "# Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dE9BtQhU1p0P"
   },
   "outputs": [],
   "source": [
    "import timm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt                        \n",
    "import torch\n",
    "import torchvision.models as models\n",
    "import torch.nn as nn\n",
    "import os,sys\n",
    "import scipy.io as sio\n",
    "import pdb\n",
    "from time import time\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import Dataset\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import Dataset\n",
    "import random\n",
    "import torchvision.transforms.functional as TF\n",
    "from tqdm.notebook import tqdm\n",
    "from scipy import spatial\n",
    "from scipy.special import softmax\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GV0vVi0J2fse"
   },
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H7_vnkmG9JZ8"
   },
   "source": [
    "### Downloading Images \n",
    "It is recommended to download the images for the desired datasets before continuing to run the code.\n",
    "\n",
    "Images can be downloaded via the following links:\n",
    "\n",
    "\n",
    "**AWA2**: https://cvml.ist.ac.at/AwA2/AwA2-data.zip\n",
    "\n",
    "\n",
    "**CUB**: http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz\n",
    "\n",
    "\n",
    "**SUN**: http://cs.brown.edu/~gmpatter/Attributes/SUNAttributeDB_Images.tar.gz\n",
    "\n",
    "*Refer to the attached .txt file named \"Dataset_Instruction\" for more information.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ynSJKlAU8YK0"
   },
   "source": [
    "### Downloading Attributes\n",
    "For more information, refer to https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/zero-shot-learning/zero-shot-learning-the-good-the-bad-and-the-ugly \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DyXA4-t39I-D"
   },
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SxYs_zC28jMi"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Download the xlsa17 archive unless it has been extracted already.\n",
    "# The guard looks for ./data/xlsa17 rather than ./data: the image paths used\n",
    "# by this notebook also live under ./data, so that folder existing does not\n",
    "# mean the archive is present.\n",
    "set -e  # stop on a failed download instead of unzipping a missing file\n",
    "if [ -d \"./data/xlsa17\" ]\n",
    "then\n",
    "    echo \"Files are already there.\"\n",
    "else\n",
    "    mkdir -p ./data\n",
    "    wget -q \"http://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip\"\n",
    "    unzip -q xlsa17.zip -d ./data\n",
    "fi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HoUxvCLR4l5l"
   },
   "source": [
    "### Choose the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MiXo9Mcx2s-o"
   },
   "outputs": [],
   "source": [
    "# Dataset to run on; one of \"AWA2\", \"CUB\" or \"SUN\".\n",
    "DATASET = \"CUB\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6g8SBGx-4pgH"
   },
   "source": [
    "Set Dataset Paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pa4OpphZ6N_l"
   },
   "outputs": [],
   "source": [
    "# Image root folder of every supported dataset.\n",
    "IMAGE_ROOTS = {\n",
    "    'AWA2': './data/AWA2/Animals_with_Attributes2/JPEGImages/',\n",
    "    'CUB': './data/CUB/CUB_200_2011/CUB_200_2011/images/',\n",
    "    'SUN': './data/SUN/images/',\n",
    "}\n",
    "if DATASET not in IMAGE_ROOTS:\n",
    "    # Fail here with a clear message instead of a NameError on ROOT later on.\n",
    "    raise ValueError(f\"Unknown DATASET {DATASET!r}; choose one of {sorted(IMAGE_ROOTS)}\")\n",
    "ROOT = IMAGE_ROOTS[DATASET]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UNizCSl24p8c"
   },
   "outputs": [],
   "source": [
    "DATA_DIR = f'./data/xlsa17/data/{DATASET}'\n",
    "# Fixed seed for the two stratified hold-out splits below, so that re-running\n",
    "# the notebook reproduces the same validation sets.\n",
    "SPLIT_SEED = 0\n",
    "\n",
    "# res101.mat is read for the image file names and the per-image class labels\n",
    "data = sio.loadmat(f'{DATA_DIR}/res101.mat')\n",
    "# att_splits.mat is read for the class attributes and the split indices\n",
    "att_splits = sio.loadmat(f'{DATA_DIR}/att_splits.mat')\n",
    "\n",
    "# keep only the part of each stored path that follows the image folder name\n",
    "path_marker = 'JPEGImages/' if DATASET == 'AWA2' else 'images/'\n",
    "image_files = np.array([entry[0][0].split(path_marker)[-1] for entry in data['image_files']])\n",
    "\n",
    "# labels are indexed from 1 as it was done in Matlab, so 1 subtracted for Python\n",
    "labels = data['labels'].squeeze().astype(np.int64) - 1\n",
    "train_idx = att_splits['train_loc'].squeeze() - 1\n",
    "val_idx = att_splits['val_loc'].squeeze() - 1\n",
    "trainval_idx = att_splits['trainval_loc'].squeeze() - 1\n",
    "test_seen_idx = att_splits['test_seen_loc'].squeeze() - 1\n",
    "test_unseen_idx = att_splits['test_unseen_loc'].squeeze() - 1\n",
    "\n",
    "# labels of the original train / val splits, used to stratify the hold-outs\n",
    "train_labels = labels[train_idx]\n",
    "val_labels = labels[val_idx]\n",
    "\n",
    "# split train_idx to train_idx (used for training) and val_seen_idx\n",
    "train_idx, val_seen_idx = train_test_split(train_idx, test_size=0.2, stratify=train_labels, random_state=SPLIT_SEED)\n",
    "# keep only the 20% hold-out of val_idx as val_unseen_idx\n",
    "val_unseen_idx = train_test_split(val_idx, test_size=0.2, stratify=val_labels, random_state=SPLIT_SEED)[1]\n",
    "# attribute matrix, transposed so that it is indexed by class label\n",
    "attrs_mat = att_splits[\"att\"].astype(np.float32).T\n",
    "\n",
    "### used for validation\n",
    "# train files and labels\n",
    "train_files = image_files[train_idx]\n",
    "train_labels = labels[train_idx]\n",
    "# *_based0 re-indexes the labels to 0..K-1 in the order of uniq_train_labels\n",
    "uniq_train_labels, train_labels_based0, counts_train_labels = np.unique(train_labels, return_inverse=True, return_counts=True)\n",
    "# val seen files and labels\n",
    "val_seen_files = image_files[val_seen_idx]\n",
    "val_seen_labels = labels[val_seen_idx]\n",
    "uniq_val_seen_labels = np.unique(val_seen_labels)\n",
    "# val unseen files and labels\n",
    "val_unseen_files = image_files[val_unseen_idx]\n",
    "val_unseen_labels = labels[val_unseen_idx]\n",
    "uniq_val_unseen_labels = np.unique(val_unseen_labels)\n",
    "\n",
    "### used for testing\n",
    "# trainval files and labels\n",
    "trainval_files = image_files[trainval_idx]\n",
    "trainval_labels = labels[trainval_idx]\n",
    "uniq_trainval_labels, trainval_labels_based0, counts_trainval_labels = np.unique(trainval_labels, return_inverse=True, return_counts=True)\n",
    "# test seen files and labels\n",
    "test_seen_files = image_files[test_seen_idx]\n",
    "test_seen_labels = labels[test_seen_idx]\n",
    "uniq_test_seen_labels = np.unique(test_seen_labels)\n",
    "# test unseen files and labels\n",
    "test_unseen_files = image_files[test_unseen_idx]\n",
    "test_unseen_labels = labels[test_unseen_idx]\n",
    "uniq_test_unseen_labels = np.unique(test_unseen_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YvD9KYa7HFDd"
   },
   "source": [
    "# Data Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataLoader(Dataset):\n",
    "    \"\"\"Image dataset yielding ``(image, label)`` pairs read from disk.\n",
    "\n",
    "    NOTE: despite its name this is a ``Dataset``. The name is kept because other\n",
    "    cells construct it as ``DataLoader(...)``; it shadows the ``DataLoader``\n",
    "    imported from ``torch.utils.data``, so batching loaders have to be created\n",
    "    through the fully qualified ``torch.utils.data.DataLoader``.\n",
    "\n",
    "    Args:\n",
    "        root: directory that the entries of ``image_files`` are relative to.\n",
    "        image_files: indexable collection of image paths.\n",
    "        labels: indexable collection of labels aligned with ``image_files``.\n",
    "        transform: optional callable applied to the RGB PIL image.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root, image_files, labels, transform=None):\n",
    "        self.root  = root\n",
    "        self.image_files = image_files\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the (optionally transformed) RGB image and the label at ``idx``.\"\"\"\n",
    "        # Open inside a context manager so the underlying file is closed once\n",
    "        # the converted copy has been made.\n",
    "        with Image.open(os.path.join(self.root, self.image_files[idx])) as raw:\n",
    "            img = raw.convert(\"RGB\")\n",
    "        # Without a transform the RGB PIL image itself is returned.\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        return img, self.labels[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of images in the dataset.\"\"\"\n",
    "        return len(self.image_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "k2fvK83IHgc3"
   },
   "source": [
    "# Transformations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "38ODTaNRHtj4"
   },
   "outputs": [],
   "source": [
    "# Training Transformations\n",
    "trainTransform = transforms.Compose([\n",
    "                        transforms.Resize((224, 224)),\n",
    "                        transforms.RandomHorizontalFlip(),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize((0.485, 0.456, 0.406), \n",
    "                                             (0.229, 0.224, 0.225))])\n",
    "# Testing Transformations\n",
    "testTransform = transforms.Compose([\n",
    "                        transforms.Resize((224, 224)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize((0.485, 0.456, 0.406), \n",
    "                                             (0.229, 0.224, 0.225))])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AFCakFQZH4ew"
   },
   "source": [
    "### Average meter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5Ni1zP7TH7yr"
   },
   "outputs": [],
   "source": [
    "class AverageMeter(object):\n",
    "    \"\"\"Keeps the latest value together with a running weighted sum, count and mean.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Set every tracked statistic back to zero.\"\"\"\n",
    "        self.val = self.avg = self.sum = self.count = 0\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        \"\"\"Record ``val`` as observed ``n`` times and refresh the running mean.\"\"\"\n",
    "        self.val = val\n",
    "        self.count += n\n",
    "        self.sum += val * n\n",
    "        self.avg = self.sum / self.count"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M_bMTKlLIHWY"
   },
   "source": [
    "# Train Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TPCWpV0pIKVM"
   },
   "outputs": [],
   "source": [
    "def train(model, data_loader, train_attrbs, optimizer, use_cuda, lamb_1=1.0):\n",
    "    \"\"\"Train ``model`` for one pass over ``data_loader``; nothing is returned.\n",
    "\n",
    "    Every batch goes through ``model.vit`` (first output only) and ``model.mlp_g``;\n",
    "    the result is multiplied with ``train_attrbs.T`` to obtain class logits, and\n",
    "    the cross-entropy against the batch labels, scaled by ``lamb_1``, is minimised\n",
    "    with ``optimizer``. The sample-weighted average loss is shown on the progress\n",
    "    bar and printed once the pass is over.\n",
    "    \"\"\"\n",
    "    running_loss = AverageMeter()\n",
    "    model.train()\n",
    "    progress = tqdm(data_loader, total=int(len(data_loader)))\n",
    "    for images, targets in progress:\n",
    "        # move to GPU\n",
    "        if use_cuda:\n",
    "            images, targets = images.cuda(), targets.cuda()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # global feature -> attribute space -> one logit per training class\n",
    "        attr_pred = model.mlp_g(model.vit(images)[0])\n",
    "        class_logits = attr_pred @ train_attrbs.T\n",
    "        loss = lamb_1 * F.cross_entropy(class_logits, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss.update(loss.item(), targets.shape[0])\n",
    "        progress.set_postfix({\"loss\": running_loss.avg})\n",
    "\n",
    "    # print training statistics\n",
    "    print('Train: Average loss: {:.4f}\\n'.format(running_loss.avg))\n",
    "    \n",
    "\n",
    "def get_reprs(model, data_loader, use_cuda):\n",
    "    \"\"\"Embed every batch of ``data_loader`` and return the features as one numpy array.\n",
    "\n",
    "    Switches ``model`` to eval mode, runs each batch through ``model.vit`` (keeping\n",
    "    only its first output) and ``model.mlp_g`` without gradient tracking, and\n",
    "    concatenates the per-batch results along axis 0. The labels are ignored.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    chunks = []\n",
    "    with torch.no_grad():\n",
    "        for images, _ in data_loader:\n",
    "            if use_cuda:\n",
    "                images = images.cuda()\n",
    "            # only take the global feature\n",
    "            feats = model.mlp_g(model.vit(images)[0])\n",
    "            chunks.append(feats.cpu().data.numpy())\n",
    "    return np.concatenate(chunks, 0)\n",
    "\n",
    "def compute_accuracy(pred_labels, true_labels, labels):\n",
    "    \"\"\"Mean of the per-class accuracies over the classes listed in ``labels``.\n",
    "\n",
    "    For each class the accuracy is the fraction of samples whose true label is\n",
    "    that class and whose prediction matches it.\n",
    "    \"\"\"\n",
    "    hits = pred_labels == true_labels\n",
    "    per_class = np.zeros(labels.shape[0])\n",
    "    for pos, cls in enumerate(labels):\n",
    "        members = true_labels == cls\n",
    "        per_class[pos] = np.sum(hits[members]) / np.sum(members)\n",
    "    return np.mean(per_class)\n",
    "\n",
    "def validation(model, seen_loader, seen_labels, unseen_loader, unseen_labels, attrs_mat, use_cuda, gamma=None):\n",
    "    \"\"\"Print ZSL / GZSL validation metrics and return the best calibration factor.\n",
    "\n",
    "    Args:\n",
    "        model: passed to ``get_reprs`` to embed both loaders.\n",
    "        seen_loader, unseen_loader: loaders of the seen / unseen validation images.\n",
    "        seen_labels, unseen_labels: class labels (row indices of ``attrs_mat``)\n",
    "            aligned with the two loaders.\n",
    "        attrs_mat: class-attribute matrix with one row per class.\n",
    "        use_cuda: passed to ``get_reprs``.\n",
    "        gamma: unused; every candidate 0.0, 0.1, ..., 1.0 is always evaluated.\n",
    "            Kept so that existing calls remain valid.\n",
    "\n",
    "    Returns:\n",
    "        The candidate with the highest harmonic mean of the seen and unseen\n",
    "        accuracies (0 when no candidate scores above zero).\n",
    "    \"\"\"\n",
    "    # Representation\n",
    "    with torch.no_grad():\n",
    "        seen_reprs = get_reprs(model, seen_loader, use_cuda)\n",
    "        unseen_reprs = get_reprs(model, unseen_loader, use_cuda)\n",
    "\n",
    "    # Labels: re-index them so that they address rows of the truncated attribute matrix\n",
    "    uniq_labels = np.unique(np.concatenate([seen_labels, unseen_labels]))\n",
    "    updated_seen_labels = np.searchsorted(uniq_labels, seen_labels)\n",
    "    uniq_updated_seen_labels = np.unique(updated_seen_labels)\n",
    "    updated_unseen_labels = np.searchsorted(uniq_labels, unseen_labels)\n",
    "    uniq_updated_unseen_labels = np.unique(updated_unseen_labels)\n",
    "\n",
    "    # truncate the attribute matrix\n",
    "    trunc_attrs_mat = attrs_mat[uniq_labels]\n",
    "\n",
    "    #### ZSL #### (candidates restricted to the unseen classes)\n",
    "    zsl_unseen_sim = unseen_reprs @ trunc_attrs_mat[uniq_updated_unseen_labels].T\n",
    "    pred_labels = np.argmax(zsl_unseen_sim, axis=1)\n",
    "    zsl_unseen_predict_labels = uniq_updated_unseen_labels[pred_labels]\n",
    "    zsl_unseen_acc = compute_accuracy(zsl_unseen_predict_labels, updated_unseen_labels, uniq_updated_unseen_labels)\n",
    "\n",
    "    #### GZSL #### (candidates are all classes present in either label set)\n",
    "    # seen classes\n",
    "    gzsl_seen_sim = softmax(seen_reprs @ trunc_attrs_mat.T, axis=1)\n",
    "    # unseen classes\n",
    "    gzsl_unseen_sim = softmax(unseen_reprs @ trunc_attrs_mat.T, axis=1)\n",
    "\n",
    "    gamma_opt = 0\n",
    "    H_max = 0\n",
    "    gzsl_seen_acc_max = 0\n",
    "    gzsl_unseen_acc_max = 0\n",
    "    # Calibrated stacking: subtract the candidate from the scores of the seen classes\n",
    "    for candidate in np.arange(0.0, 1.1, 0.1):\n",
    "        gamma_mat = np.zeros(trunc_attrs_mat.shape[0])\n",
    "        gamma_mat[uniq_updated_seen_labels] = candidate\n",
    "\n",
    "        gzsl_seen_pred_labels = np.argmax(gzsl_seen_sim - gamma_mat, axis=1)\n",
    "        gzsl_seen_acc = compute_accuracy(gzsl_seen_pred_labels, updated_seen_labels, uniq_updated_seen_labels)\n",
    "\n",
    "        gzsl_unseen_pred_labels = np.argmax(gzsl_unseen_sim - gamma_mat, axis=1)\n",
    "        gzsl_unseen_acc = compute_accuracy(gzsl_unseen_pred_labels, updated_unseen_labels, uniq_updated_unseen_labels)\n",
    "\n",
    "        acc_total = gzsl_seen_acc + gzsl_unseen_acc\n",
    "        # With both accuracies at zero the harmonic mean would be 0 / 0; use 0 instead.\n",
    "        H = 2 * gzsl_seen_acc * gzsl_unseen_acc / acc_total if acc_total > 0 else 0.0\n",
    "\n",
    "        if H > H_max:\n",
    "            gzsl_seen_acc_max = gzsl_seen_acc\n",
    "            gzsl_unseen_acc_max = gzsl_unseen_acc\n",
    "            H_max = H\n",
    "            gamma_opt = candidate\n",
    "\n",
    "    print('ZSL: averaged per-class accuracy: {0:.2f}'.format(zsl_unseen_acc * 100))\n",
    "    print('GZSL Seen: averaged per-class accuracy: {0:.2f}'.format(gzsl_seen_acc_max * 100))\n",
    "    print('GZSL Unseen: averaged per-class accuracy: {0:.2f}'.format(gzsl_unseen_acc_max * 100))\n",
    "    print('GZSL: harmonic mean (H): {0:.2f}'.format(H_max * 100))\n",
    "    print('GZSL: gamma: {0:.2f}'.format(gamma_opt))\n",
    "\n",
    "    return gamma_opt\n",
    "\n",
    "\n",
    "def test(model, test_seen_loader, test_seen_labels, test_unseen_loader, test_unseen_labels, attrs_mat, use_cuda, gamma):\n",
    "    \"\"\"Print ZSL / GZSL test metrics for a fixed calibration factor.\n",
    "\n",
    "    Args:\n",
    "        model: passed to ``get_reprs`` to embed both loaders.\n",
    "        test_seen_loader, test_unseen_loader: loaders of the seen / unseen test images.\n",
    "        test_seen_labels, test_unseen_labels: class labels (row indices of\n",
    "            ``attrs_mat``) aligned with the two loaders.\n",
    "        attrs_mat: class-attribute matrix with one row per class.\n",
    "        use_cuda: passed to ``get_reprs``.\n",
    "        gamma: value subtracted from the softmax scores of the seen test classes.\n",
    "\n",
    "    Returns:\n",
    "        None; the metrics are only printed.\n",
    "    \"\"\"\n",
    "    # Representation\n",
    "    with torch.no_grad():\n",
    "        seen_reprs = get_reprs(model, test_seen_loader, use_cuda)\n",
    "        unseen_reprs = get_reprs(model, test_unseen_loader, use_cuda)\n",
    "    # Labels\n",
    "    uniq_test_seen_labels = np.unique(test_seen_labels)\n",
    "    uniq_test_unseen_labels = np.unique(test_unseen_labels)\n",
    "\n",
    "    # ZSL (candidates restricted to the unseen test classes)\n",
    "    zsl_unseen_sim = unseen_reprs @ attrs_mat[uniq_test_unseen_labels].T\n",
    "    predict_labels = np.argmax(zsl_unseen_sim, axis=1)\n",
    "    zsl_unseen_predict_labels = uniq_test_unseen_labels[predict_labels]\n",
    "    zsl_unseen_acc = compute_accuracy(zsl_unseen_predict_labels, test_unseen_labels, uniq_test_unseen_labels)\n",
    "\n",
    "    # Calibrated stacking: gamma for the seen test classes, 0 for every other class\n",
    "    Cs_mat = np.zeros(attrs_mat.shape[0])\n",
    "    Cs_mat[uniq_test_seen_labels] = gamma\n",
    "\n",
    "    # GZSL (candidates are all rows of attrs_mat)\n",
    "    # seen classes\n",
    "    gzsl_seen_sim = softmax(seen_reprs @ attrs_mat.T, axis=1) - Cs_mat\n",
    "    gzsl_seen_predict_labels = np.argmax(gzsl_seen_sim, axis=1)\n",
    "    gzsl_seen_acc = compute_accuracy(gzsl_seen_predict_labels, test_seen_labels, uniq_test_seen_labels)\n",
    "\n",
    "    # unseen classes\n",
    "    gzsl_unseen_sim = softmax(unseen_reprs @ attrs_mat.T, axis=1) - Cs_mat\n",
    "    gzsl_unseen_predict_labels = np.argmax(gzsl_unseen_sim, axis=1)\n",
    "    gzsl_unseen_acc = compute_accuracy(gzsl_unseen_predict_labels, test_unseen_labels, uniq_test_unseen_labels)\n",
    "\n",
    "    acc_total = gzsl_unseen_acc + gzsl_seen_acc\n",
    "    # With both accuracies at zero the harmonic mean would be 0 / 0; report 0 instead.\n",
    "    H = 2 * gzsl_unseen_acc * gzsl_seen_acc / acc_total if acc_total > 0 else 0.0\n",
    "\n",
    "    print('ZSL: averaged per-class accuracy: {0:.2f}'.format(zsl_unseen_acc * 100))\n",
    "    print('GZSL Seen: averaged per-class accuracy: {0:.2f}'.format(gzsl_seen_acc * 100))\n",
    "    print('GZSL Unseen: averaged per-class accuracy: {0:.2f}'.format(gzsl_unseen_acc * 100))\n",
    "    print('GZSL: harmonic mean (H): {0:.2f}'.format(H * 100))\n",
    "    print('GZSL: gamma: {0:.2f}'.format(gamma))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U6Ch_i62IxQT"
   },
   "source": [
    "# Data Loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_workers = 4\n",
    "TRAIN_BATCH_SIZE = 32\n",
    "EVAL_BATCH_SIZE = 256\n",
    "\n",
    "\n",
    "def _make_balanced_loader(files, labels_based0, class_counts):\n",
    "    \"\"\"Build (dataset, sampler, loader) for training with class-balanced sampling.\n",
    "\n",
    "    Each sample is weighted by 1 / (number of samples in its class) and drawn\n",
    "    with replacement, len(labels_based0) draws per pass.\n",
    "    \"\"\"\n",
    "    # NOTE: DataLoader here is the image Dataset class defined in this notebook.\n",
    "    dataset = DataLoader(ROOT, files, labels_based0, transform=trainTransform)\n",
    "    sample_weights = (1. / class_counts)[labels_based0]\n",
    "    sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=labels_based0.shape[0], replacement=True)\n",
    "    loader = torch.utils.data.DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, sampler=sampler, num_workers=num_workers)\n",
    "    return dataset, sampler, loader\n",
    "\n",
    "\n",
    "def _make_eval_loader(files, eval_labels):\n",
    "    \"\"\"Build (dataset, loader) for evaluation: test transform, fixed order.\"\"\"\n",
    "    dataset = DataLoader(ROOT, files, eval_labels, transform=testTransform)\n",
    "    loader = torch.utils.data.DataLoader(dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=num_workers)\n",
    "    return dataset, loader\n",
    "\n",
    "\n",
    "### used in validation\n",
    "# train data loader\n",
    "train_data, train_sampler, train_data_loader = _make_balanced_loader(train_files, train_labels_based0, counts_train_labels)\n",
    "# seen val data loader\n",
    "val_seen_data, val_seen_data_loader = _make_eval_loader(val_seen_files, val_seen_labels)\n",
    "# unseen val data loader\n",
    "val_unseen_data, val_unseen_data_loader = _make_eval_loader(val_unseen_files, val_unseen_labels)\n",
    "\n",
    "### used in testing\n",
    "# trainval data loader\n",
    "trainval_data, trainval_sampler, trainval_data_loader = _make_balanced_loader(trainval_files, trainval_labels_based0, counts_trainval_labels)\n",
    "# seen test data loader\n",
    "test_seen_data, test_seen_data_loader = _make_eval_loader(test_seen_files, test_seen_labels)\n",
    "# unseen test data loader\n",
    "test_unseen_data, test_unseen_data_loader = _make_eval_loader(test_unseen_files, test_unseen_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8F2lgZOAI24k"
   },
   "source": [
    "# Baseline Model (ViT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3W08lH9FI8O-"
   },
   "outputs": [],
   "source": [
    "class ViT(nn.Module):\n",
    "    \"\"\"timm Vision Transformer wrapper that returns token embeddings.\n",
    "\n",
    "    Args:\n",
    "        model_name: model identifier passed to ``timm.create_model``.\n",
    "        pretrained: forwarded to ``timm.create_model``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_name=\"vit_large_patch16_224_in21k\", pretrained=True):\n",
    "        super(ViT, self).__init__()\n",
    "        self.vit = timm.create_model(model_name, pretrained=pretrained)\n",
    "        # Others variants of ViT can be used as well\n",
    "        #   1 --- 'vit_small_patch16_224'\n",
    "        #   2 --- 'vit_base_patch16_224'\n",
    "        #   3 --- 'vit_large_patch16_224',\n",
    "        #   4 --- 'vit_large_patch32_224'\n",
    "        #   5 --- 'vit_deit_base_patch16_224'\n",
    "        #   6 --- 'deit_base_distilled_patch16_224',\n",
    "\n",
    "        # Replace the classification head with a no-op; forward() below never calls it.\n",
    "        self.vit.head = nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(x[:, 0], x[:, 1:])``: the class-token output and all remaining tokens.\"\"\"\n",
    "        x = self.vit.patch_embed(x)\n",
    "        cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)\n",
    "        # Some timm model objects do not define dist_token at all; treat a\n",
    "        # missing attribute exactly like an attribute that is None.\n",
    "        dist_token = getattr(self.vit, \"dist_token\", None)\n",
    "        if dist_token is None:\n",
    "            x = torch.cat((cls_token, x), dim=1)\n",
    "        else:\n",
    "            x = torch.cat((cls_token, dist_token.expand(x.shape[0], -1, -1), x), dim=1)\n",
    "        x = self.vit.pos_drop(x + self.vit.pos_embed)\n",
    "        x = self.vit.blocks(x)\n",
    "        x = self.vit.norm(x)\n",
    "\n",
    "        return x[:, 0], x[:, 1:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oB0jr5sMJLdv"
   },
   "source": [
    "# Model and Optimizer Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "V1lLFB0qJNqH"
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "from torch import optim\n",
    "use_cuda = torch.cuda.is_available()\n",
    "\n",
    "# The projection needs one output per attribute. attrs_mat holds one row per\n",
    "# class, so its second dimension is the attribute length of the chosen dataset.\n",
    "attr_length = attrs_mat.shape[1]\n",
    "\n",
    "vit = ViT(\"vit_large_patch16_224_in21k\")\n",
    "# Input width is read from the backbone (timm exposes it as num_features)\n",
    "# so that it follows the model name above.\n",
    "mlp_g = nn.Linear(vit.vit.num_features, attr_length, bias=False)\n",
    "\n",
    "model = nn.ModuleDict({\n",
    "    \"vit\": vit,\n",
    "    \"mlp_g\": mlp_g})\n",
    "\n",
    "# finetune all the parameters\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = True\n",
    "\n",
    "# move model to GPU if CUDA is available\n",
    "if use_cuda:\n",
    "    model = model.cuda()\n",
    "\n",
    "# separate learning rate / weight decay for the backbone and the projection\n",
    "optimizer = torch.optim.Adam([{\"params\": model.vit.parameters(), \"lr\": 0.00001, \"weight_decay\": 0.0001},\n",
    "                              {\"params\": model.mlp_g.parameters(), \"lr\": 0.001, \"weight_decay\": 0.00001}])\n",
    "\n",
    "# halve both learning rates at each milestone (counted in lr_scheduler.step() calls)\n",
    "lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40], gamma=0.5)\n",
    "\n",
    "\n",
    "# train attributes\n",
    "train_attrbs = attrs_mat[uniq_train_labels]\n",
    "train_attrbs_tensor = torch.from_numpy(train_attrbs)\n",
    "# trainval attributes\n",
    "trainval_attrbs = attrs_mat[uniq_trainval_labels]\n",
    "trainval_attrbs_tensor = torch.from_numpy(trainval_attrbs)\n",
    "if use_cuda:\n",
    "    train_attrbs_tensor = train_attrbs_tensor.cuda()\n",
    "    trainval_attrbs_tensor = trainval_attrbs_tensor.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bare expression: the notebook displays the repr of the assembled module.\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0YznkklMKV25"
   },
   "source": [
    "# Training and Testing the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pPCFuP9dK3wn"
   },
   "source": [
    "### Setting the calibration factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1mPTYAvVKXf9"
   },
   "outputs": [],
   "source": [
    "# Only run this cell if you are to tune the calibration factor (gamma).\n",
    "# It is data-dependent, and decided based on the validation set.\n",
    "gammas = []\n",
    "for i in range(20):\n",
    "    train(model, train_data_loader, train_attrbs_tensor, optimizer, use_cuda, lamb_1=1.0)\n",
    "    lr_scheduler.step()\n",
    "    # best calibration factor on the validation split after this epoch\n",
    "    gamma = validation(\n",
    "        model,\n",
    "        val_seen_data_loader, val_seen_labels,\n",
    "        val_unseen_data_loader, val_unseen_labels,\n",
    "        attrs_mat, use_cuda,\n",
    "    )\n",
    "    gammas.append(gamma)\n",
    "# average of the per-epoch optima\n",
    "gamma = np.mean(gammas)\n",
    "print(gamma)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s8lAQPtcLttQ"
   },
   "source": [
    "### Calibration factor is Set\n",
    "It is 0.9 for AWA2 and CUB,\n",
    "and 0.4 for SUN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1Q8fdI_8L9pN"
   },
   "outputs": [],
   "source": [
    "# Calibration factor used for each supported dataset.\n",
    "CALIBRATION_GAMMA = {'AWA2': 0.9, 'CUB': 0.9, 'SUN': 0.4}\n",
    "if DATASET not in CALIBRATION_GAMMA:\n",
    "    # Fail loudly rather than continue with a stale or undefined gamma.\n",
    "    raise ValueError(f\"No calibration factor for DATASET {DATASET!r}; choose one of {sorted(CALIBRATION_GAMMA)}\")\n",
    "gamma = CALIBRATION_GAMMA[DATASET]\n",
    "print('Dataset:', DATASET, '\\nGamma:',gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Checkpoints are written to ./checkpoint/<DATASET>/; create the folder up\n",
    "# front so that torch.save does not fail on a missing directory.\n",
    "ckpt_path = './checkpoint/' + str(DATASET)\n",
    "os.makedirs(ckpt_path, exist_ok=True)\n",
    "\n",
    "for i in range(80):\n",
    "    train(model, trainval_data_loader, trainval_attrbs_tensor, optimizer, use_cuda, lamb_1=1.0)\n",
    "    print(' .... Saving model ...')\n",
    "    print('Epoch: ', i)\n",
    "    # one checkpoint file per epoch\n",
    "    save_path= str(DATASET) + '__ViT-ZSL__' +'Epoch_' + str(i) + '.pt'\n",
    "    path = os.path.join(ckpt_path, save_path)\n",
    "    torch.save(model.state_dict(), path)\n",
    "\n",
    "    lr_scheduler.step()\n",
    "    test(model, test_seen_data_loader, test_seen_labels, test_unseen_data_loader, test_unseen_labels, attrs_mat, use_cuda, gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "Nr2V4fFG5MSw",
    "ynSJKlAU8YK0",
    "aceqa0YpGrpj",
    "YvD9KYa7HFDd",
    "AFCakFQZH4ew",
    "M_bMTKlLIHWY",
    "8F2lgZOAI24k"
   ],
   "name": "Explicit_AND_Implicit_Attention.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
